[Neuron][Kernel] NKI-based flash-attention kernel with paged KV cache #11277
Conversation
@robertgshaw2-neuralmagic @WoosukKwon PTAL
return o

def flash_attn_varlen_nkifunc(
Do you think we could align this to the API of the unified flash attention function for V1?
e.g. in v1/attention/backends/flash_attention
flash_attn_varlen_func(
q=query[:num_actual_tokens],
k=key_cache,
v=value_cache,
out=output[:num_actual_tokens],
cu_seqlens_q=attn_metadata.query_start_loc,
max_seqlen_q=attn_metadata.max_query_len,
cu_seqlens_k=attn_metadata.seq_start_loc,
max_seqlen_k=attn_metadata.max_seq_len,
softmax_scale=self.scale,
causal=True,
alibi_slopes=self.alibi_slopes,
window_size=self.sliding_window,
block_table=attn_metadata.block_table,
softcap=self.logits_soft_cap,
)
Aligning to this interface will reduce special cases.
Great question!
There are three challenges blocking alignment with that interface:
- Calling reshape_and_cache_flash before flash_attn_varlen_func can create an inefficient read-after-write dependency, due to synchronous processing on a heterogeneous architecture. On the Neuron stack, it is better to merge cached tokens with new tokens asynchronously. Therefore, we pass both the cached KV and the active KV to the flash-attention call, so that writing to HBM does not block computation.
- Since slicing to num_actual_tokens would not reduce computation or bandwidth utilization, I think it is better to slice the logits in the last attention layer.
- We find it more efficient to compute the attention mask before each layer. Therefore, instead of passing sequence lengths and query lengths, we pass the pre-computed attention mask directly to the flash-attention kernel (see the illustrative sketch below).
I'm actively looking for ideas that can help close the gap between the interfaces without degrading the kernel's performance.
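A minimal, hypothetical sketch of that mask pre-computation step (the helper name, mask layout, and key ordering here are illustrative assumptions, not the kernel's actual format):

import numpy as np

# Hypothetical illustration only -- not the actual mask-construction code in
# this PR. It shows the idea of building the attention mask on the host ahead
# of the kernel call, instead of passing sequence/query lengths into it.
def build_prefill_mask(query_lens, context_lens):
    total_q = sum(query_lens)
    total_k = sum(q + c for q, c in zip(query_lens, context_lens))
    mask = np.zeros((total_q, total_k), dtype=bool)
    q_off, k_off = 0, 0
    for q_len, c_len in zip(query_lens, context_lens):
        causal = np.tril(np.ones((q_len, q_len), dtype=bool))
        mask[q_off:q_off + q_len, k_off:k_off + c_len] = True               # cached tokens: fully visible
        mask[q_off:q_off + q_len, k_off + c_len:k_off + c_len + q_len] = causal  # new tokens: causal
        q_off += q_len
        k_off += c_len + q_len
    return mask

# Two sequences in one batch: 3 and 2 new query tokens, with 4 and 1 cached tokens.
print(build_prefill_mask([3, 2], [4, 1]).astype(int))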
@robertgshaw2-redhat do you think the interface misalignment can be a blocking issue for merging?
Nice work!
@simon-mo @WoosukKwon do you have any other concerns / comments?
This is a very isolated change. I'm okay with merging this.
None), "continuous_batching_mask does not support logit_bias!" | ||
|
||
# mask are used to only apply computation to the lower half of the matrix, | ||
# which reduce the arthimetic intensity by half |
Here is a typo: arthimetic. It breaks the CI.
PR to fix the failing pre-commit: #12497
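On the "lower half of the matrix" comment above, a tiny, generic NumPy sketch (illustrative only, not the kernel's code) of why a causal mask roughly halves the number of score entries that must be computed:

import numpy as np

# A lower-triangular (causal) mask over an n x n score tile keeps
# n*(n+1)/2 of the n*n entries, so roughly half of the score
# computation in that tile can be skipped.
n = 128
causal_mask = np.tril(np.ones((n, n), dtype=bool))
kept = causal_mask.sum() / causal_mask.size
print(f"fraction of score entries computed: {kept:.3f}")  # ~0.504, tends to 0.5 for large n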
Summary
FIX #11152
This PR introduces an NKI-based kernel that brings support for chunked prefill with flash attention.
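For background, a small, generic NumPy sketch of the paged-KV-cache idea referenced in the title (the block size and shapes are illustrative assumptions, not values used by this kernel):

import numpy as np

# Generic illustration of a paged KV cache: the cache is a pool of fixed-size
# blocks, and a per-sequence block table maps logical token positions to
# physical blocks.
block_size, num_blocks, head_dim = 4, 6, 2
key_cache = np.arange(num_blocks * block_size * head_dim,
                      dtype=np.float32).reshape(num_blocks, block_size, head_dim)

block_table = [5, 0, 3]   # this sequence's blocks, in logical order
seq_len = 10              # tokens currently stored for the sequence

# Gather the sequence's keys: logically contiguous, physically scattered.
keys = key_cache[block_table].reshape(-1, head_dim)[:seq_len]
print(keys.shape)         # (10, 2)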
Co-authored-by: Jiangfei Duan [email protected]